{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mengfanxu/miniconda3/envs/transmla/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from qwen2.modeling_qwen2 import Qwen2ForCausalLM\n",
    "from transformers import AutoTokenizer\n",
    "import torch\n",
    "from copy import deepcopy\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:20<00:00, 10.45s/it]\n",
      "Some weights of Qwen2ForCausalLM were not initialized from the model checkpoint at /data2/mengfanxu/huggingface/Qwen2.5-3B and are newly initialized: ['model.layers.0.self_attn.k_up_proj.weight', 'model.layers.0.self_attn.v_up_proj.weight', 'model.layers.1.self_attn.k_up_proj.weight', 'model.layers.1.self_attn.v_up_proj.weight', 'model.layers.10.self_attn.k_up_proj.weight', 'model.layers.10.self_attn.v_up_proj.weight', 'model.layers.11.self_attn.k_up_proj.weight', 'model.layers.11.self_attn.v_up_proj.weight', 'model.layers.12.self_attn.k_up_proj.weight', 'model.layers.12.self_attn.v_up_proj.weight', 'model.layers.13.self_attn.k_up_proj.weight', 'model.layers.13.self_attn.v_up_proj.weight', 'model.layers.14.self_attn.k_up_proj.weight', 'model.layers.14.self_attn.v_up_proj.weight', 'model.layers.15.self_attn.k_up_proj.weight', 'model.layers.15.self_attn.v_up_proj.weight', 'model.layers.16.self_attn.k_up_proj.weight', 'model.layers.16.self_attn.v_up_proj.weight', 'model.layers.17.self_attn.k_up_proj.weight', 'model.layers.17.self_attn.v_up_proj.weight', 'model.layers.18.self_attn.k_up_proj.weight', 'model.layers.18.self_attn.v_up_proj.weight', 'model.layers.19.self_attn.k_up_proj.weight', 'model.layers.19.self_attn.v_up_proj.weight', 'model.layers.2.self_attn.k_up_proj.weight', 'model.layers.2.self_attn.v_up_proj.weight', 'model.layers.20.self_attn.k_up_proj.weight', 'model.layers.20.self_attn.v_up_proj.weight', 'model.layers.21.self_attn.k_up_proj.weight', 'model.layers.21.self_attn.v_up_proj.weight', 'model.layers.22.self_attn.k_up_proj.weight', 'model.layers.22.self_attn.v_up_proj.weight', 'model.layers.23.self_attn.k_up_proj.weight', 'model.layers.23.self_attn.v_up_proj.weight', 'model.layers.24.self_attn.k_up_proj.weight', 'model.layers.24.self_attn.v_up_proj.weight', 'model.layers.25.self_attn.k_up_proj.weight', 'model.layers.25.self_attn.v_up_proj.weight', 'model.layers.26.self_attn.k_up_proj.weight', 'model.layers.26.self_attn.v_up_proj.weight', 
'model.layers.27.self_attn.k_up_proj.weight', 'model.layers.27.self_attn.v_up_proj.weight', 'model.layers.28.self_attn.k_up_proj.weight', 'model.layers.28.self_attn.v_up_proj.weight', 'model.layers.29.self_attn.k_up_proj.weight', 'model.layers.29.self_attn.v_up_proj.weight', 'model.layers.3.self_attn.k_up_proj.weight', 'model.layers.3.self_attn.v_up_proj.weight', 'model.layers.30.self_attn.k_up_proj.weight', 'model.layers.30.self_attn.v_up_proj.weight', 'model.layers.31.self_attn.k_up_proj.weight', 'model.layers.31.self_attn.v_up_proj.weight', 'model.layers.32.self_attn.k_up_proj.weight', 'model.layers.32.self_attn.v_up_proj.weight', 'model.layers.33.self_attn.k_up_proj.weight', 'model.layers.33.self_attn.v_up_proj.weight', 'model.layers.34.self_attn.k_up_proj.weight', 'model.layers.34.self_attn.v_up_proj.weight', 'model.layers.35.self_attn.k_up_proj.weight', 'model.layers.35.self_attn.v_up_proj.weight', 'model.layers.4.self_attn.k_up_proj.weight', 'model.layers.4.self_attn.v_up_proj.weight', 'model.layers.5.self_attn.k_up_proj.weight', 'model.layers.5.self_attn.v_up_proj.weight', 'model.layers.6.self_attn.k_up_proj.weight', 'model.layers.6.self_attn.v_up_proj.weight', 'model.layers.7.self_attn.k_up_proj.weight', 'model.layers.7.self_attn.v_up_proj.weight', 'model.layers.8.self_attn.k_up_proj.weight', 'model.layers.8.self_attn.v_up_proj.weight', 'model.layers.9.self_attn.k_up_proj.weight', 'model.layers.9.self_attn.v_up_proj.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Qwen2ForCausalLM(\n",
       "  (model): Qwen2Model(\n",
       "    (embed_tokens): Embedding(151936, 2048)\n",
       "    (layers): ModuleList(\n",
       "      (0-35): 36 x Qwen2DecoderLayer(\n",
       "        (self_attn): Qwen2SdpaMLAttention(\n",
       "          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)\n",
       "          (k_proj): Linear(in_features=2048, out_features=256, bias=True)\n",
       "          (k_up_proj): Linear(in_features=256, out_features=2048, bias=False)\n",
       "          (v_proj): Linear(in_features=2048, out_features=256, bias=True)\n",
       "          (v_up_proj): Linear(in_features=256, out_features=2048, bias=False)\n",
       "          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
       "          (rotary_emb): Qwen2RotaryEmbedding()\n",
       "        )\n",
       "        (mlp): Qwen2MLP(\n",
       "          (gate_proj): Linear(in_features=2048, out_features=11008, bias=False)\n",
       "          (up_proj): Linear(in_features=2048, out_features=11008, bias=False)\n",
       "          (down_proj): Linear(in_features=11008, out_features=2048, bias=False)\n",
       "          (act_fn): SiLU()\n",
       "        )\n",
       "        (input_layernorm): Qwen2RMSNorm((2048,), eps=1e-06)\n",
       "        (post_attention_layernorm): Qwen2RMSNorm((2048,), eps=1e-06)\n",
       "      )\n",
       "    )\n",
       "    (norm): Qwen2RMSNorm((2048,), eps=1e-06)\n",
       "    (rotary_emb): Qwen2RotaryEmbedding()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=2048, out_features=151936, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Checkpoint location and target device are named once so the model and tokenizer loads cannot drift apart.\n",
    "MODEL_PATH = \"/data2/mengfanxu/huggingface/Qwen2.5-3B\"  # machine-specific local path; adjust as needed\n",
    "DEVICE = 'cuda:1'\n",
    "\n",
    "# The project-local Qwen2ForCausalLM has k_up_proj / v_up_proj layers that are absent from this checkpoint,\n",
    "# so from_pretrained reports them as newly initialized (see the warning in this cell's output).\n",
    "model = Qwen2ForCausalLM.from_pretrained(MODEL_PATH, attn_implementation=\"sdpa\", device_map=DEVICE)\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)\n",
    "model  # bare last expression: display the module tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attention geometry read from the loaded config; later cells reshape projection weights with these values.\n",
    "cfg = model.config\n",
    "hidden_size = cfg.hidden_size\n",
    "n_heads = cfg.num_attention_heads\n",
    "kv_heads = cfg.num_key_value_heads\n",
    "head_dim = hidden_size // n_heads   # width of one attention head\n",
    "kv_groups = n_heads // kv_heads     # query heads per key/value head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Insert identity matrices\n",
    "# Every k_up_proj / v_up_proj receives the same (hidden_size, kv_heads*head_dim) weight: the identity over the\n",
    "# key/value width, with each head's block of rows stacked kv_groups times. Build it once instead of per module.\n",
    "kv_dim = kv_heads * head_dim\n",
    "per_head_identity = torch.eye(kv_dim).reshape(kv_heads, head_dim, kv_dim)\n",
    "identity_up = torch.stack([per_head_identity] * kv_groups, dim=1).reshape(hidden_size, kv_dim).contiguous()\n",
    "\n",
    "for name, module in model.named_modules():\n",
    "    if 'k_up_proj' in name or \"v_up_proj\" in name:\n",
    "        old_weight = module.weight.data\n",
    "        # copy=True gives each layer its own storage even when device and dtype already match.\n",
    "        module.weight.data = identity_up.to(device=old_weight.device, dtype=old_weight.dtype, copy=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:None for open-end generation.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "['Give me a short introduction to large language model. A large language model (LLM) is a type of artificial intelligence (AI']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke test with the identity up-projections in place: generate a short continuation of a fixed prompt.\n",
    "prompt = \"Give me a short introduction to large language model.\"\n",
    "# Send the inputs to wherever the model lives instead of repeating the device literal from the load cell.\n",
    "inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
    "output = model.generate(**inputs, max_new_tokens=16)\n",
    "tokenizer.batch_decode(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "583it [00:48, 12.04it/s]\n"
     ]
    }
   ],
   "source": [
    "def _refactor_low_rank_pair(down_proj, up_proj):\n",
    "    \"\"\"Re-balance one (down_proj, up_proj) pair with a per-head rank-``head_dim`` SVD.\n",
    "\n",
    "    For every key/value head the product ``up_proj.weight @ down_proj.weight`` is formed (with\n",
    "    ``down_proj.bias`` appended as one extra input column when it exists) and decomposed as\n",
    "    ``U diag(S) V^T``. ``U sqrt(S)`` is written to ``up_proj.weight`` and ``sqrt(S) V^T`` to\n",
    "    ``down_proj.weight`` / ``down_proj.bias``, so the pair's product is kept up to the accuracy\n",
    "    of the rank-``head_dim`` decomposition.\n",
    "\n",
    "    Uses the notebook globals ``hidden_size``, ``kv_heads`` and ``head_dim``; modifies both modules.\n",
    "\n",
    "    Args:\n",
    "        down_proj: ``k_proj`` or ``v_proj`` (hidden_size -> kv_heads*head_dim, optional bias).\n",
    "        up_proj: the matching ``k_up_proj`` or ``v_up_proj`` (kv_heads*head_dim -> hidden_size).\n",
    "    \"\"\"\n",
    "    up_w = up_proj.weight.data.reshape(hidden_size, kv_heads, head_dim)\n",
    "    down_w = down_proj.weight.data.reshape(kv_heads, head_dim, hidden_size)\n",
    "    has_bias = down_proj.bias is not None\n",
    "    if has_bias:\n",
    "        # Treat the bias as an extra input column so it is factorised together with the weight.\n",
    "        bias_col = down_proj.bias.data.reshape(kv_heads, head_dim, 1)\n",
    "        down_w = torch.cat([down_w, bias_col], dim=-1)\n",
    "    # (kv_heads, hidden_size, hidden_size [+1 with bias]); each head's slice has rank <= head_dim\n",
    "    product = torch.einsum(\"Dhd,hdL->hDL\", up_w, down_w)\n",
    "    U, S, V = torch.svd_lowrank(product, head_dim, niter=head_dim)\n",
    "    S_sqrt = torch.sqrt(S)  # (kv_heads, head_dim)\n",
    "    new_up = torch.einsum('hDd,hd->Dhd', U, S_sqrt)    # (hidden_size, kv_heads, head_dim)\n",
    "    new_down = torch.einsum('hd,hLd->hdL', S_sqrt, V)  # (kv_heads, head_dim, hidden_size [+1 with bias])\n",
    "    if has_bias:\n",
    "        # The last column is the factorised bias; strip it before storing the weight.\n",
    "        down_proj.bias.data = new_down[:, :, -1].reshape(-1).contiguous()\n",
    "        new_down = new_down[:, :, :-1]\n",
    "    up_proj.weight.data = new_up.reshape(hidden_size, kv_heads * head_dim).contiguous()\n",
    "    down_proj.weight.data = new_down.reshape(kv_heads * head_dim, hidden_size).contiguous()\n",
    "\n",
    "\n",
    "for name, module in tqdm(model.named_modules()):\n",
    "    if name.endswith(\"self_attn\"):\n",
    "        _refactor_low_rank_pair(module.k_proj, module.k_up_proj)  # Orthogonal k_proj and k_up_proj\n",
    "        _refactor_low_rank_pair(module.v_proj, module.v_up_proj)  # Orthogonal v_proj and v_up_proj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:None for open-end generation.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "['Give me a short introduction to large language model. A large language model (LLM) is a type of artificial intelligence (AI']"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Repeat the earlier generation check after the SVD re-factorisation; the decoded text should be unchanged.\n",
    "prompt = \"Give me a short introduction to large language model.\"\n",
    "# Send the inputs to wherever the model lives instead of repeating the device literal from the load cell.\n",
    "inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
    "output = model.generate(**inputs, max_new_tokens=16)\n",
    "tokenizer.batch_decode(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the converted model to a local directory.\n",
    "save_dir = \"qwen2.5_3b_transMLA\"\n",
    "model.save_pretrained(save_dir)\n",
    "# Upload left disabled:\n",
    "#model.push_to_hub(\"fxmeng/qwen2.5_3b_transMLA\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('qwen2.5_3b_transMLA/tokenizer_config.json',\n",
       " 'qwen2.5_3b_transMLA/special_tokens_map.json',\n",
       " 'qwen2.5_3b_transMLA/vocab.json',\n",
       " 'qwen2.5_3b_transMLA/merges.txt',\n",
       " 'qwen2.5_3b_transMLA/added_tokens.json',\n",
       " 'qwen2.5_3b_transMLA/tokenizer.json')"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Store the tokenizer files in the same directory as the saved model; the written paths are displayed.\n",
    "saved_files = tokenizer.save_pretrained(\"qwen2.5_3b_transMLA\")\n",
    "saved_files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the checkpoint named below and run the same fixed prompt through it.\n",
    "# NOTE(review): this rebinds `model` and `tokenizer`; any cell run afterwards operates on this checkpoint.\n",
    "CHECKPOINT_ID = \"fxmeng/qwen2.5_0.5b_instruct_transMLA\"\n",
    "model = Qwen2ForCausalLM.from_pretrained(CHECKPOINT_ID, attn_implementation=\"eager\", device_map='cuda:0')\n",
    "tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_ID)\n",
    "# Follow the model's device rather than repeating the literal used in device_map.\n",
    "inputs = tokenizer(\"Give me a short introduction to large language model.\", return_tensors=\"pt\").to(model.device)\n",
    "output = model.generate(**inputs, max_new_tokens=16)\n",
    "tokenizer.batch_decode(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.colors import ListedColormap\n",
    "# Two-colour map (white / blue) shared by the imshow cells.\n",
    "two_tone = [\"white\", \"blue\"]\n",
    "cmap = ListedColormap(two_tone)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7595880e8220>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGhCAYAAADbf0s2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAhSElEQVR4nO3df2xV9f3H8dctpZcK3Ftbx729o9XOkQCCiBRqwWz7jmagDsdEHaRORAJTi/LDH9CZ4pxiETfnUIRpHJoIoiSCQiaGFQTZSoECTkQKxgY68BaV9V5+2FJ7P98/DFeuFijltvdzb5+P5CT2nHMP709S7yvv9z3n1mGMMQIAwEJJsS4AAIAzIaQAANYipAAA1iKkAADWIqQAANYipAAA1iKkAADWIqQAANYipAAA1iKkAADWillILViwQJdddpm6dOmivLw8bdmyJValAAAsFZOQev311zVjxgw98sgj2r59uwYMGKARI0bo8OHDsSgHAGApRyy+YDYvL0+DBw/Wc889J0kKhULKysrSvffeq1mzZp3z9aFQSIcOHVL37t3lcDjaulwAQJQZY3T06FH5fD4lJZ25X0pux5okSSdPnlRlZaWKi4vD+5KSklRQUKDy8vJmX9PQ0KCGhobwzwcPHlTfvn3bvFYAQNuqqalRz549z3i83UPqiy++UFNTkzweT8R+j8ejPXv2NPua0tJSPfroo80cqZHkUiAQ/ToBAG0nGAwqKytL3bt3P+t57R5SrVFcXKwZM2aEfz61uEDAJZfLpdMnfvx1LACIH+f6yKbdQ+qSSy5Rp06dVFtbG7G/trZWXq+32dc4nU45nc72KA8AYJF2v7svJSVFgwYNUllZWXhfKBRSWVmZ8vPz27scAIDFYjLumzFjhsaPH6/c3FwNGTJEzzzzjI4fP64JEya06nqnj/gY/QFA4ohJSP3mN7/R559/rtmzZ8vv9+uqq67SmjVrvnczBQCgY4vJc1IXKhgMyu12KxAIyOVyRRyjkwIA+53tffx0cXF33/lg9AcAiYMvmAUAWIuQAgBYK+HGfadrbvTH2A8A4gedFADAWgndSZ3uVAfFzRQAED/opAAA1iKkAADW6jDjvlN4jgoA4gedFADAWoQUAMBaHW7cdzpGfwBgNzopAIC1CCkAgLU69LjvdIz+AMA+dFIAAGsRUgAAazHuawajPwCwA50UAMBahBQAwFqM+86B0R8AxA6dFADAWoQUAMBajPvOA6M/AGhfdFIAAGsRUgAAazHuayVGfwDQ9uikAADWIqQAANZi3BcFzY3+GPsBwIWjkwIAWItOKspOdVDcTAEAF45OCgBgLUIKAGAtxn1thOeoAODC0UkBAKxFSAEArMW4rx0w+gOA1qGTAgBYi5ACAFiLcV87Y/QHAC1HJwUAsBYhBQCwFuO+GGL0BwBnRycFALAWIQUAsBbjPksw+gOA74t6J1VaWqrBgwere/fu6tGjh0aPHq2qqqqIc+rr61VUVKSMjAx169ZNY8aMUW1tbbRLAQDEuaiH1IYNG1RUVKTNmzdr7dq1amxs1C9+8QsdP348fM706dO1atUqLV++XBs2bNChQ4d00003RbsUAECccxjTtgOlzz//XD169NCGDRv0k5/8RIFAQD/4wQ+0dOlS3XzzzZKkPXv2qE+fPiovL9c111xzzmsGg0G53W4FAgG5XK62LD/mGP0BSEQtfR9v8xsnAoGAJCk9PV2SVFlZqcbGRhUUFITP6d27t7Kzs1VeXt7sNRoaGhQMBiM2AEDia9OQCoVCmjZtmoYNG6Z+/fpJkvx+v1JSUpSWlhZxrsfjkd/vb/Y6paWlcrvd4S0rK6stywYAWKJNQ6qoqEi7du3SsmXLLug6xcXFCgQC4a2mpiZKFdrPmG83h+PbDQA6gja7BX3KlClavXq1Nm7cqJ49e4b3e71enTx5UnV1dRHdVG1trbxeb7PXcjqdcjqdbVUqAMBSUe+k
jDGaMmWKVqxYoXXr1iknJyfi+KBBg9S5c2eVlZWF91VVVenAgQPKz8+PdjkAgDgW9U6qqKhIS5cu1VtvvaXu3buHP2dyu91KTU2V2+3WxIkTNWPGDKWnp8vlcunee+9Vfn5+i+7s68iae+CXO/4AJLKo34LuOMMHJosXL9Ydd9wh6ZuHee+//3699tpramho0IgRI/T888+fcdz3XR3pFvQzIaQAxLOWvo+3+XNSbYGQ+hbPUQGIR9Y8JwUAQGsRUgAAa/Et6HGOb08HkMjopAAA1iKkAADWYtyXQBj9AUg0dFIAAGsRUgAAazHuS1CM/gAkAjopAIC1CCkAgLUY93UAjP4AxCs6KQCAtQgpAIC1GPd1MIz+AMQTOikAgLUIKQCAtRj3dWCM/gDYjk4KAGAtQgoAYC3GfZDE6A+AneikAADWIqQAANZi3IfvaW70x9gPQCzQSQEArEUnhbM61UFxMwWAWKCTAgBYi5ACAFiLcR9ahBEfgFigkwIAWIuQAgBYi5ACAFiLkAIAWIuQAgBYi5ACAFiLkAIAWIuQAgBYi5ACAFiLkAIAWIuQAgBYi5ACAFiLkAIAWIuQAgBYi5ACAFiLkAIAWIuQAgBYq81Dau7cuXI4HJo2bVp4X319vYqKipSRkaFu3bppzJgxqq2tbetSAABxpk1DauvWrfrb3/6mK6+8MmL/9OnTtWrVKi1fvlwbNmzQoUOHdNNNN7VlKQCAONRmIXXs2DEVFhbqxRdf1MUXXxzeHwgE9NJLL+npp5/Wz3/+cw0aNEiLFy/Wv//9b23evLmtygEAxKE2C6mioiLdcMMNKigoiNhfWVmpxsbGiP29e/dWdna2ysvL26ocAEAcSm6Liy5btkzbt2/X1q1bv3fM7/crJSVFaWlpEfs9Ho/8fn+z12toaFBDQ0P452AwGNV6AQB2inonVVNTo6lTp2rJkiXq0qVLVK5ZWloqt9sd3rKysqJyXQCA3aIeUpWVlTp8+LCuvvpqJScnKzk5WRs2bND8+fOVnJwsj8ejkydPqq6uLuJ1tbW18nq9zV6zuLhYgUAgvNXU1ES7bACAhaI+7hs+fLg+/PDDiH0TJkxQ7969NXPmTGVlZalz584qKyvTmDFjJElVVVU6cOCA8vPzm72m0+mU0+mMdqkAAMtFPaS6d++ufv36Rezr2rWrMjIywvsnTpyoGTNmKD09XS6XS/fee6/y8/N1zTXXRLscAEAca5MbJ87lL3/5i5KSkjRmzBg1NDRoxIgRev7552NRCgDAYg5jjIl1EecrGAzK7XYrEAjI5XLFuhwAwHlq6fs4390HALAWIQUAsBYhBQCwFiEFALAWIQUAsBYhBQCwFiEFALAWIQUAsBYhBQCwFiEFALAWIQUAsBYhBQCwFiEFALAWIQUAsBYhBQCwFiEFALAWIQUAsBYhBQCwFiEFALAWIQUAsBYhBQCwFiEFALAWIQUAsBYhBQCwFiEFALAWIQUAsBYhBQCwFiEFALAWIQUAsBYhBQCwFiEFALAWIQUAsBYhBQCwFiEFALAWIQUAsBYhBQCwFiEFALAWIQUAsBYhBQCwFiEFALAWIQUAsBYhBQCwFiEFALAWIQUAsBYhBQCwFiEFALAWIQUAsFabhNTBgwd12223KSMjQ6mpqerfv7+2bdsWPm6M0ezZs5WZmanU1FQVFBRo3759bVEKACCORT2k/ve//2nYsGHq3Lmz3nnnHe3evVt//vOfdfHFF4fPmTdvnubPn69FixapoqJCXbt21YgRI1RfXx/tcgAAccxhjDHRvOCsWbP0r3/9S++//36zx40x8vl8uv/++/XAAw9IkgKBgDwej15++WWNHTv2nP9GMBiU2+1WIBCQy+WKZvkAgHbQ0vfxqHdSb7/9tnJzc3XLLbeoR48eGjhwoF588cXw8erqavn9fhUUFIT3ud1u5eXlqby8vNlrNjQ0KBgMRmwAgMQX9ZD69NNPtXDhQvXq1Uvvvvuu7r77bt13
33165ZVXJEl+v1+S5PF4Il7n8XjCx76rtLRUbrc7vGVlZUW7bACAhaIeUqFQSFdffbWeeOIJDRw4UJMnT9akSZO0aNGiVl+zuLhYgUAgvNXU1ESxYgCAraIeUpmZmerbt2/Evj59+ujAgQOSJK/XK0mqra2NOKe2tjZ87LucTqdcLlfEBgBIfFEPqWHDhqmqqipi3969e3XppZdKknJycuT1elVWVhY+HgwGVVFRofz8/GiXAwCIY8nRvuD06dM1dOhQPfHEE7r11lu1ZcsWvfDCC3rhhRckSQ6HQ9OmTdPjjz+uXr16KScnRyUlJfL5fBo9enS0ywEAxLGoh9TgwYO1YsUKFRcX649//KNycnL0zDPPqLCwMHzOQw89pOPHj2vy5Mmqq6vTtddeqzVr1qhLly7RLgcAEMei/pxUe+A5KQCIbzF7TgoAgGghpAAA1iKkAADWivqNEwCiw+H49r/j75NjIDropAAA1iKkAADWYtwHWOr0ER+jP3RUdFIAAGsRUgAAazHuA+IAoz90VHRSAABrEVIAAGsx7gPiDKM/dCR0UgAAaxFSAABrMe4D4lhzoz/GfkgkdFIAAGvRSQEJ4lQHxc0USCR0UgAAaxFSAABrMe4DEgzPUSGR0EkBAKxFSAEArMW4D0hgjP4Q7+ikAADWIqQAANZi3Ad0EIz+EI/opAAA1iKkAADWYtwHdECM/hAv6KQAANYipAAA1mLcB3RwjP5gMzopAIC1CCkAgLUY9wEIY/QH29BJAQCsRUgBAKzFuA9Asxj9wQZ0UgAAaxFSAABrMe4DcE7Njf4Y+6E90EkBAKxFJwXgvJzqoLiZAu2BTgoAYC1CCgBgLcZ9AFqF56jQHqLeSTU1NamkpEQ5OTlKTU3V5Zdfrscee0zmtN9cY4xmz56tzMxMpaamqqCgQPv27Yt2KQCAOBf1kHryySe1cOFCPffcc/r444/15JNPat68eXr22WfD58ybN0/z58/XokWLVFFRoa5du2rEiBGqr6+PdjkAgDjmMCa6zfkvf/lLeTwevfTSS+F9Y8aMUWpqql599VUZY+Tz+XT//ffrgQcekCQFAgF5PB69/PLLGjt27Dn/jWAwKLfbrUAgIJfLFc3yAVwgRn9oiZa+j0e9kxo6dKjKysq0d+9eSdIHH3ygTZs26brrrpMkVVdXy+/3q6CgIPwat9utvLw8lZeXN3vNhoYGBYPBiA0AkPiifuPErFmzFAwG1bt3b3Xq1ElNTU2aM2eOCgsLJUl+v1+S5PF4Il7n8XjCx76rtLRUjz76aLRLBQBYLuqd1BtvvKElS5Zo6dKl2r59u1555RX96U9/0iuvvNLqaxYXFysQCIS3mpqaKFYMIJqM+XZzOL7dgNaIeif14IMPatasWeHPlvr376/9+/ertLRU48ePl9frlSTV1tYqMzMz/Lra2lpdddVVzV7T6XTK6XRGu1QAgOWi3kmdOHFCSUmRl+3UqZNCoZAkKScnR16vV2VlZeHjwWBQFRUVys/Pj3Y5AIA4FvVOatSoUZozZ46ys7N1xRVXaMeOHXr66ad15513SpIcDoemTZumxx9/XL169VJOTo5KSkrk8/k0evToaJcDIIZ44BcXKuoh9eyzz6qkpET33HOPDh8+LJ/Pp9/97neaPXt2+JyHHnpIx48f1+TJk1VXV6drr71Wa9asUZcuXaJdDgAgjkX9Oan2wHNSQPyhk8LpWvo+znf3AWgXjP7QGnwLOgDAWoQUAMBajPsAtDtGf2gpOikAgLUIKQCAtRj3AYgpRn84GzopAIC1CCkAgLUY9wGwRnOjP8Z+HRudFADAWnRSAKx0qoPiZoqOjU4KAGAtQgoAYC3GfQCsxnNUHRudFADAWoQUAMBajPsAxA1Gfx0PnRQAwFqEFADAWoz7AMQlRn8dA50UAMBahBQAwFqM+wDEPUZ/iYtOCgBgLUIKAGAtxn0AEgqjv8RCJwUAsBYhBQCwFuM+AAmL0V/8o5MCAFiL
kAIAWItxH4AOgdFffKKTAgBYi5ACAFiLcR+ADqe50R9jPzvRSQEArEUnBaBDO9VBcTOFneikAADWIqQAANZi3AcA4jkqW9FJAQCsRUgBAKzFuA8AvoPRnz3opAAA1iKkAADWYtwHAGfB6C+2zruT2rhxo0aNGiWfzyeHw6GVK1dGHDfGaPbs2crMzFRqaqoKCgq0b9++iHOOHDmiwsJCuVwupaWlaeLEiTp27NgFLQQAkHjOO6SOHz+uAQMGaMGCBc0enzdvnubPn69FixapoqJCXbt21YgRI1RfXx8+p7CwUB999JHWrl2r1atXa+PGjZo8eXLrVwEASEzmAkgyK1asCP8cCoWM1+s1Tz31VHhfXV2dcTqd5rXXXjPGGLN7924jyWzdujV8zjvvvGMcDoc5ePBgi/7dQCBgJJlAIHAh5QNAq30z8Ptmw/lr6ft4VG+cqK6ult/vV0FBQXif2+1WXl6eysvLJUnl5eVKS0tTbm5u+JyCggIlJSWpoqKi2es2NDQoGAxGbACAxBfVkPL7/ZIkj8cTsd/j8YSP+f1+9ejRI+J4cnKy0tPTw+d8V2lpqdxud3jLysqKZtkAAEvFxS3oxcXFCgQC4a2mpibWJQHo4E4f+Dkc326IrqiGlNfrlSTV1tZG7K+trQ0f83q9Onz4cMTxr7/+WkeOHAmf811Op1MulytiAwAkvqiGVE5Ojrxer8rKysL7gsGgKioqlJ+fL0nKz89XXV2dKisrw+esW7dOoVBIeXl50SwHABDnzvth3mPHjumTTz4J/1xdXa2dO3cqPT1d2dnZmjZtmh5//HH16tVLOTk5Kikpkc/n0+jRoyVJffr00ciRIzVp0iQtWrRIjY2NmjJlisaOHSufzxe1hQFAe+GB3zZ0vrcNrl+/3kj63jZ+/HhjzDe3oZeUlBiPx2OcTqcZPny4qaqqirjGl19+acaNG2e6detmXC6XmTBhgjl69GjUb10EgPbGrekt09L3cYcx8Zf1wWBQbrdbgUCAz6cAWIVOqmVa+j7Od/cBQBQx+ouuuLgFHQDQMRFSAABrMe4DgDbC6O/C0UkBAKxFSAEArMW4DwDaQXOjP8Z+50YnBQCwFp0UALSzUx0UN1OcG50UAMBahBQAwFqM+wAgRniO6tzopAAA1iKkAADWYtwHABZg9Nc8OikAgLUIKQCAtRj3AYBlGP19i04KAGAtQgoAYC3GfQBgsY4++qOTAgBYi5ACAFiLcR8AxImOOPqjkwIAWIuQAgBYi3EfAMShjjL6o5MCAFiLkAIAWItxHwDEuUQe/dFJAQCsRUgBAKzFuA8AEkhzo794HvvRSQEArEUnBQAJ6lQHFc83U9BJAQCsRUgBAKzFuA8AElw8P0dFJwUAsBYhBQCwFuM+AOhA4m30RycFALAWIQUAsBbjPgDooOJh9EcnBQCwFiEFALDWeYfUxo0bNWrUKPl8PjkcDq1cuTJ8rLGxUTNnzlT//v3VtWtX+Xw+3X777Tp06FDENY4cOaLCwkK5XC6lpaVp4sSJOnbs2AUvBgDQOsZ8u9nkvEPq+PHjGjBggBYsWPC9YydOnND27dtVUlKi7du3680331RVVZVuvPHGiPMKCwv10Ucfae3atVq9erU2btyoyZMnt34VAICE5DCm9bnpcDi0YsUKjR49+oznbN26VUOGDNH+/fuVnZ2tjz/+WH379tXWrVuVm5srSVqzZo2uv/56/fe//5XP5zvnvxsMBuV2uxUIBORyuVpbPgAgRlr6Pt7mn0kFAgE5HA6lpaVJksrLy5WWlhYOKEkqKChQUlKSKioq2rocAEAcadNb0Ovr6zVz5kyNGzcunJR+v189evSILCI5Wenp6fL7/c1ep6GhQQ0NDeGfg8Fg2xUNALBGm3VSjY2NuvXWW2WM0cKFCy/oWqWlpXK73eEtKysrSlUCAGzWJiF1KqD279+vtWvXRswbvV6vDh8+HHH+119/rSNHjsjr9TZ7veLiYgUCgfBW
U1PTFmUDACwT9XHfqYDat2+f1q9fr4yMjIjj+fn5qqurU2VlpQYNGiRJWrdunUKhkPLy8pq9ptPplNPpjHapAADLnXdIHTt2TJ988kn45+rqau3cuVPp6enKzMzUzTffrO3bt2v16tVqamoKf86Unp6ulJQU9enTRyNHjtSkSZO0aNEiNTY2asqUKRo7dmyL7uwDAHQc530L+nvvvaf/+7//+97+8ePH6w9/+INycnKafd369ev1s5/9TNI3D/NOmTJFq1atUlJSksaMGaP58+erW7duLaqBW9ABIL619H38gp6TihVCCgDimzXPSQEA0FqEFADAWoQUAMBahBQAwFqEFADAWoQUAMBahBQAwFqEFADAWoQUAMBahBQAwFqEFADAWoQUAMBahBQAwFqEFADAWlH/y7zt4dRfFwkGgzGuBADQGqfev8/116LiMqSOHj0qScrKyopxJQCAC3H06FG53e4zHo/LP3oYCoV06NAhGWOUnZ2tmpqahP3jh8FgUFlZWQm9Rol1JpqOsM6OsEap7dZpjNHRo0fl8/mUlHTmT57ispNKSkpSz549w+2iy+VK6F8SqWOsUWKdiaYjrLMjrFFqm3WerYM6hRsnAADWIqQAANaK65ByOp165JFH5HQ6Y11Km+kIa5RYZ6LpCOvsCGuUYr/OuLxxAgDQMcR1JwUASGyEFADAWoQUAMBahBQAwFpxG1ILFizQZZddpi5duigvL09btmyJdUkXpLS0VIMHD1b37t3Vo0cPjR49WlVVVRHn1NfXq6ioSBkZGerWrZvGjBmj2traGFV84ebOnSuHw6Fp06aF9yXKGg8ePKjbbrtNGRkZSk1NVf/+/bVt27bwcWOMZs+erczMTKWmpqqgoED79u2LYcXnr6mpSSUlJcrJyVFqaqouv/xyPfbYYxHfxRaP69y4caNGjRoln88nh8OhlStXRhxvyZqOHDmiwsJCuVwupaWlaeLEiTp27Fg7ruLszrbGxsZGzZw5U/3791fXrl3l8/l0++2369ChQxHXaLc1mji0bNkyk5KSYv7+97+bjz76yEyaNMmkpaWZ2traWJfWaiNGjDCLFy82u3btMjt37jTXX3+9yc7ONseOHQufc9ddd5msrCxTVlZmtm3bZq655hozdOjQGFbdelu2bDGXXXaZufLKK83UqVPD+xNhjUeOHDGXXnqpueOOO0xFRYX59NNPzbvvvms++eST8Dlz5841brfbrFy50nzwwQfmxhtvNDk5Oearr76KYeXnZ86cOSYjI8OsXr3aVFdXm+XLl5tu3bqZv/71r+Fz4nGd//jHP8zDDz9s3nzzTSPJrFixIuJ4S9Y0cuRIM2DAALN582bz/vvvmx//+Mdm3Lhx7bySMzvbGuvq6kxBQYF5/fXXzZ49e0x5ebkZMmSIGTRoUMQ12muNcRlSQ4YMMUVFReGfm5qajM/nM6WlpTGsKroOHz5sJJkNGzYYY775xencubNZvnx5+JyPP/7YSDLl5eWxKrNVjh49anr16mXWrl1rfvrTn4ZDKlHWOHPmTHPttdee8XgoFDJer9c89dRT4X11dXXG6XSa1157rT1KjIobbrjB3HnnnRH7brrpJlNYWGiMSYx1fvcNvCVr2r17t5Fktm7dGj7nnXfeMQ6Hwxw8eLDdam+p5oL4u7Zs2WIkmf379xtj2neNcTfuO3nypCorK1VQUBDel5SUpIKCApWXl8ewsugKBAKSpPT0dElSZWWlGhsbI9bdu3dvZWdnx926i4qKdMMNN0SsRUqcNb799tvKzc3VLbfcoh49emjgwIF68cUXw8erq6vl9/sj1ul2u5WXlxdX6xw6dKjKysq0d+9eSdIHH3ygTZs26brrrpOUOOs8XUvWVF5errS0NOXm5obPKSgoUFJSkioqKtq95mgIBAJyOBxKS0uT1L5rjLsvmP3iiy/U1NQkj8cTsd/j8WjPnj0xqiq6QqGQpk2bpmHDhqlfv36SJL/fr5SUlPAv
ySkej0d+vz8GVbbOsmXLtH37dm3duvV7xxJljZ9++qkWLlyoGTNm6Pe//722bt2q++67TykpKRo/fnx4Lc39DsfTOmfNmqVgMKjevXurU6dOampq0pw5c1RYWChJCbPO07VkTX6/Xz169Ig4npycrPT09Lhcd319vWbOnKlx48aFv2C2PdcYdyHVERQVFWnXrl3atGlTrEuJqpqaGk2dOlVr165Vly5dYl1OmwmFQsrNzdUTTzwhSRo4cKB27dqlRYsWafz48TGuLnreeOMNLVmyREuXLtUVV1yhnTt3atq0afL5fAm1zo6ssbFRt956q4wxWrhwYUxqiLtx3yWXXKJOnTp9746v2tpaeb3eGFUVPVOmTNHq1au1fv169ezZM7zf6/Xq5MmTqqurizg/ntZdWVmpw4cP6+qrr1ZycrKSk5O1YcMGzZ8/X8nJyfJ4PHG/RknKzMxU3759I/b16dNHBw4ckKTwWuL9d/jBBx/UrFmzNHbsWPXv31+//e1vNX36dJWWlkpKnHWeriVr8nq9Onz4cMTxr7/+WkeOHImrdZ8KqP3792vt2rURf6ajPdcYdyGVkpKiQYMGqaysLLwvFAqprKxM+fn5MazswhhjNGXKFK1YsULr1q1TTk5OxPFBgwapc+fOEeuuqqrSgQMH4mbdw4cP14cffqidO3eGt9zcXBUWFob/O97XKEnDhg373uMDe/fu1aWXXipJysnJkdfrjVhnMBhURUVFXK3zxIkT3/tjdZ06dVIoFJKUOOs8XUvWlJ+fr7q6OlVWVobPWbdunUKhkPLy8tq95tY4FVD79u3TP//5T2VkZEQcb9c1RvU2jHaybNky43Q6zcsvv2x2795tJk+ebNLS0ozf7491aa129913G7fbbd577z3z2WefhbcTJ06Ez7nrrrtMdna2Wbdundm2bZvJz883+fn5Maz6wp1+d58xibHGLVu2mOTkZDNnzhyzb98+s2TJEnPRRReZV199NXzO3LlzTVpamnnrrbfMf/7zH/OrX/3K+luzv2v8+PHmhz/8YfgW9DfffNNccskl5qGHHgqfE4/rPHr0qNmxY4fZsWOHkWSefvpps2PHjvCdbS1Z08iRI83AgQNNRUWF2bRpk+nVq5dVt6CfbY0nT540N954o+nZs6fZuXNnxPtRQ0ND+Brttca4DCljjHn22WdNdna2SUlJMUOGDDGbN2+OdUkXRFKz2+LFi8PnfPXVV+aee+4xF198sbnooovMr3/9a/PZZ5/Frugo+G5IJcoaV61aZfr162ecTqfp3bu3eeGFFyKOh0IhU1JSYjwej3E6nWb48OGmqqoqRtW2TjAYNFOnTjXZ2dmmS5cu5kc/+pF5+OGHI97I4nGd69evb/b/xfHjxxtjWramL7/80owbN85069bNuFwuM2HCBHP06NEYrKZ5Z1tjdXX1Gd+P1q9fH75Ge62RP9UBALBW3H0mBQDoOAgpAIC1CCkAgLUIKQCAtQgpAIC1CCkAgLUIKQCAtQgpAIC1CCkAgLUIKQCAtQgpAIC1CCkAgLX+H5T823VbqyxAAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Gram matrix W @ W.T of layer 0's v_proj weight, drawn with the two-colour map.\n",
    "v_weight = model.model.layers[0].self_attn.v_proj.weight.data.to(\"cpu\")\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(v_weight @ v_weight.T, cmap=cmap, interpolation='none')\n",
    "ax.set_title(\"layer 0 v_proj: W @ W.T\")\n",
    "ax.set_xlabel(\"v_proj output index\")\n",
    "ax.set_ylabel(\"v_proj output index\")\n",
    "plt.show()  # render the figure without echoing the AxesImage repr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x759587fe85b0>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGhCAYAAADbf0s2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAh2ElEQVR4nO3dfXBU5fnG8WtDyBKB3ZhYdrMl0dQyAwgi8hIDTttfyRTUYqmohYkVkYGqQQn4AqkTrFUMYmstilAdi84IoswIClNxaECQNgQIYEUk4MhACm5QaXZ5kRDZ5/eHw8JigBA22Wd3v5+ZM2POOTncz0y6d+9rz9l1GGOMAACwUEqsCwAA4GxoUgAAa9GkAADWokkBAKxFkwIAWIsmBQCwFk0KAGAtmhQAwFo0KQCAtWhSAABrxaxJzZkzR1dccYU6dOig/Px8bdiwIValAAAsFZMm9eabb2rKlCl67LHHtHnzZvXp00dDhw7VgQMHYlEOAMBSjlh8wGx+fr4GDBigF154QZIUCoWUk5Oj+++/X9OmTTvv74dCIe3fv1+dO3eWw+Fo7XIBAFFmjNGhQ4fk8/mUknL2eSm1DWuSJB0/flzV1dUqLS0N70tJSVFhYaEqKyub/J2GhgY1NDSEf963b5969uzZ6rUCAFpXbW2tunbtetbjbd6kvvrqK504cUIejydiv8fj0Y4dO5r8nfLycj3++ONNHKmV5FIgEP06AQCtJxgMKicnR507dz7neW3epFqitLRUU6ZMCf98cnGBgEsul0unJ358OxYAxI/zvWXT5k3qsssuU7t27VRXVxexv66uTl6vt8nfcTqdcjqdbVEeAMAibX53X1pamvr166eKiorwvlAopIqKChUUFLR1OQAAi8Uk7psyZYrGjBmj/v37a+DAgXruued05MgRjR07tkXXOz3iI/oDgMQRkyb1m9/8Rl9++aWmT58uv9+va665RitWrPjezRQAgOQWk+ekLlYwGJTb7VYgEJDL5Yo4xiQFAPY71+v46eLi7r4LQfQHAImDD5gFAFiLJgUAsFbCxX2nayr6I/YDgPjBJAUAsFZCT1KnOzlBcTMFAMQPJikAgLVoUgAAayVN3HcSz1EBQPxgkgIAWIsmBQCwVtLFfacj+gMAuzFJAQCsRZMCAFgrqeO+0xH9AYB9mKQAANaiSQEArEXc1wSiPwCwA5MUAMBaNCkAgLWI+86D6A8AYodJCgBgLZoUAMBaxH0XgOgPANoWkxQAwFo0KQCAtYj7WojoDwBaH5MUAMBaNCkAgLWI+6KgqeiP2A8ALh6TFADAWkxSUXZyguJmCgC4eExSAABr0aQAANYi7mslPEcFABePSQoAYC2aFADAWsR9bYDoDwBahkkKAGAtmhQAwFrEfW2M6A8Amo9JCgBgLZoUAMBaxH0xRPQHAOfGJAUAsBZNCgBgLeI+SxD9AcD3RX2SKi8v14ABA9S5c2d16dJFI0aMUE1NTcQ5x44dU3FxsbKystSpUyeNHDlSdXV10S4FABDnot6k1qxZo+LiYq1fv14rV65UY2OjfvGLX+jIkSPhcyZPnqxly5Zp8eLFWrNmjfbv369bbrkl2qUAAOKcw5jWDZS+/PJLdenSRWvWrNFPfvITBQIB/eAHP9DChQt16623SpJ27NihHj16qLKyUtddd915rxkMBuV2uxUIBORyuVqz/Jgj+gOQiJr7Ot7qN04EAgFJUmZmpiSpurpajY2NKiwsDJ/TvXt35ebmqrKysslrNDQ0KBgMRmwAgMTXqk0qFAqppKREgwcPVq9evSRJfr9faWlpysjIiDjX4/HI7/c3eZ3y8nK53e7wlpOT05plAwAs0apNqri4WNu2bdOiRYsu6jqlpaUKBALhrba2NkoV2s+YU5vDcWoDgGTQaregT5w4UcuXL9fatWvVtWvX8H6v16vjx4+rvr4+Ypqqq6uT1+tt8lpOp1NOp7O1SgUAWCrqk5Qx
RhMnTtSSJUu0atUq5eXlRRzv16+f2rdvr4qKivC+mpoa7d27VwUFBdEuBwAQx6I+SRUXF2vhwoV655131Llz5/D7TG63W+np6XK73Ro3bpymTJmizMxMuVwu3X///SooKGjWnX3JrKkHfrnjD0Aii/ot6I6zvGEyf/583XXXXZK+e5j3wQcf1BtvvKGGhgYNHTpUL7744lnjvjMl0y3oZ0OTAhDPmvs63urPSbUGmtQpPEcFIB5Z85wUAAAtRZMCAFiLT0GPc3x6OoBExiQFALAWTQoAYC3ivgRC9Acg0TBJAQCsRZMCAFiLuC9BEf0BSARMUgAAa9GkAADWIu5LAkR/AOIVkxQAwFo0KQCAtYj7kgzRH4B4wiQFALAWTQoAYC3iviRG9AfAdkxSAABr0aQAANYi7oMkoj8AdmKSAgBYiyYFALAWcR++p6noj9gPQCwwSQEArMUkhXM6OUFxMwWAWGCSAgBYiyYFALAWcR+aheeoAMQCkxQAwFo0KQCAtYj7cMGI/gC0FSYpAIC1aFIAAGsR9+GiEP0BaE1MUgAAa9GkAADWIu5D1BD9AYg2JikAgLVoUgAAaxH3oVUQ/QGIBiYpAIC1aFIAAGsR96HVEf0BaCkmKQCAtWhSAABrtXqTmjlzphwOh0pKSsL7jh07puLiYmVlZalTp04aOXKk6urqWrsUWMCYU5vDcWoDgKa0apPauHGj/va3v+nqq6+O2D958mQtW7ZMixcv1po1a7R//37dcsstrVkKACAOtVqTOnz4sIqKivTyyy/r0ksvDe8PBAJ65ZVX9Oyzz+rnP/+5+vXrp/nz5+vf//631q9f31rlAADiUKs1qeLiYt10000qLCyM2F9dXa3GxsaI/d27d1dubq4qKytbqxxYiOgPwPm0yi3oixYt0ubNm7Vx48bvHfP7/UpLS1NGRkbEfo/HI7/f3+T1Ghoa1NDQEP45GAxGtV4AgJ2iPknV1tZq0qRJWrBggTp06BCVa5aXl8vtdoe3nJycqFwXAGC3qDep6upqHThwQNdee61SU1OVmpqqNWvWaPbs2UpNTZXH49Hx48dVX18f8Xt1dXXyer1NXrO0tFSBQCC81dbWRrtsxFhT0R8ARD3uGzJkiD7++OOIfWPHjlX37t01depU5eTkqH379qqoqNDIkSMlSTU1Ndq7d68KCgqavKbT6ZTT6Yx2qQAAy0W9SXXu3Fm9evWK2NexY0dlZWWF948bN05TpkxRZmamXC6X7r//fhUUFOi6666LdjmIQyc/LomPUAIQk8/u+8tf/qKUlBSNHDlSDQ0NGjp0qF588cVYlAIAsJjDmPj7/6jBYFBut1uBQEAulyvW5aCVMEkBiau5r+N8CjqsxaenA+ADZgEA1qJJAQCsRdyHuED0ByQnJikAgLVoUgAAaxH3Ie4Q/QHJg0kKAGAtmhQAwFrEfYhrRH9AYmOSAgBYiyYFALAWcR8SBtEfkHiYpAAA1qJJAQCsRdyHhET0ByQGJikAgLVoUgAAaxH3IeER/QHxi0kKAGAtmhQAwFrEfUgqTUV/xH6AvZikAADWYpJC0jo5QXEzBWAvJikAgLVoUgAAaxH3IenxHBVgLyYpAIC1aFIAAGsR9wGnIfoD7MIkBQCwFk0KAGAt4j7gLIj4gNhjkgIAWIsmBQCwFk0KAGAtmhQAwFo0KQCAtWhSAABr0aQAANaiSQEArEWTAgBYiyYFALAWTQoAYC2aFADAWjQpAIC1aFIAAGu1SpPat2+f7rjjDmVlZSk9PV29e/fWpk2bwseNMZo+fbqys7OVnp6uwsJC7dq1qzVKAQDEsag3qf/9738aPHiw2rdvr/fee0/bt2/Xn//8Z1166aXhc2bNmqXZs2dr3rx5qqqqUseOHTV06FAdO3Ys2uUAAOKYw5jofrXbtGnT9K9//Usffvhhk8eNMfL5fHrwwQf10EMPSZICgYA8Ho9effVVjRo16rz/RjAYlNvtViAQkMvlimb5AIA2
0NzX8ahPUu+++6769++v2267TV26dFHfvn318ssvh4/v3r1bfr9fhYWF4X1ut1v5+fmqrKxs8poNDQ0KBoMRGwAg8UW9SX3++eeaO3euunXrpvfff1/33nuvHnjgAb322muSJL/fL0nyeDwRv+fxeMLHzlReXi632x3ecnJyol02AMBCUW9SoVBI1157rZ566in17dtXEyZM0Pjx4zVv3rwWX7O0tFSBQCC81dbWRrFiAICtot6ksrOz1bNnz4h9PXr00N69eyVJXq9XklRXVxdxTl1dXfjYmZxOp1wuV8QGAEh8UW9SgwcPVk1NTcS+nTt36vLLL5ck5eXlyev1qqKiInw8GAyqqqpKBQUF0S4HABDHUqN9wcmTJ2vQoEF66qmndPvtt2vDhg166aWX9NJLL0mSHA6HSkpK9OSTT6pbt27Ky8tTWVmZfD6fRowYEe1yAABxLOpNasCAAVqyZIlKS0v1xz/+UXl5eXruuedUVFQUPueRRx7RkSNHNGHCBNXX1+v666/XihUr1KFDh2iXAwCIY1F/Tqot8JwUAMS3mD0nBQBAtNCkAADWokkBAKwV9RsnAESHw3Hqv+PvnWMgOpikAADWokkBAKxF3AdY6vSIj+gPyYpJCgBgLZoUAMBaxH1AHCD6Q7JikgIAWIsmBQCwFnEfEGeI/pBMmKQAANaiSQEArEXcB8SxpqI/Yj8kEiYpAIC1mKSABHFyguJmCiQSJikAgLVoUgAAaxH3AQmG56iQSJikAADWokkBAKxF3AckMKI/xDsmKQCAtWhSAABrEfcBSYLoD/GISQoAYC2aFADAWsR9QBIi+kO8YJICAFiLJgUAsBZxH5DkiP5gMyYpAIC1aFIAAGsR9wEII/qDbZikAADWokkBAKxF3AegSUR/sAGTFADAWjQpAIC1iPsAnFdT0R+xH9oCkxQAwFpMUgAuyMkJipsp0BaYpAAA1qJJAQCsRdwHoEV4jgptIeqT1IkTJ1RWVqa8vDylp6fryiuv1BNPPCFz2l+uMUbTp09Xdna20tPTVVhYqF27dkW7FABAnIt6k3r66ac1d+5cvfDCC/r000/19NNPa9asWXr++efD58yaNUuzZ8/WvHnzVFVVpY4dO2ro0KE6duxYtMsBAMQxhzHRHc5/+ctfyuPx6JVXXgnvGzlypNLT0/X666/LGCOfz6cHH3xQDz30kCQpEAjI4/Ho1Vdf1ahRo877bwSDQbndbgUCAblcrmiWD+AiEf2hOZr7Oh71SWrQoEGqqKjQzp07JUkfffSR1q1bpxtuuEGStHv3bvn9fhUWFoZ/x+12Kz8/X5WVlU1es6GhQcFgMGIDACS+qN84MW3aNAWDQXXv3l3t2rXTiRMnNGPGDBUVFUmS/H6/JMnj8UT8nsfjCR87U3l5uR5//PFolwoAsFzUJ6m33npLCxYs0MKFC7V582a99tpr+tOf/qTXXnutxdcsLS1VIBAIb7W1tVGsGEA0GXNqczhObUBLRH2SevjhhzVt2rTwe0u9e/fWnj17VF5erjFjxsjr9UqS6urqlJ2dHf69uro6XXPNNU1e0+l0yul0RrtUAIDloj5JHT16VCkpkZdt166dQqGQJCkvL09er1cVFRXh48FgUFVVVSooKIh2OQCAOBb1SWr48OGaMWOGcnNzddVVV2nLli169tlndffdd0uSHA6HSkpK9OSTT6pbt27Ky8tTWVmZfD6fRowYEe1yAMQQD/ziYkW9ST3//PMqKyvTfffdpwMHDsjn8+l3v/udpk+fHj7nkUce0ZEjRzRhwgTV19fr+uuv14oVK9ShQ4dolwMAiGNRf06qLfCcFBB/mKRwuua+jvPZfQDaBNEfWoJPQQcAWIsmBQCwFnEfgDZH9IfmYpICAFiLJgUAsBZxH4CYIvrDuTBJAQCsRZMCAFiLuA+ANZqK/oj9khuTFADAWkxSAKx0coLiZorkxiQFALAWTQoAYC3iPgBW4zmq5MYkBQCwFk0KAGAt
4j4AcYPoL/kwSQEArEWTAgBYi7gPQFwi+ksOTFIAAGvRpAAA1iLuAxD3iP4SF5MUAMBaNCkAgLWI+wAkFKK/xMIkBQCwFk0KAGAt4j4ACYvoL/4xSQEArEWTAgBYi7gPQFIg+otPTFIAAGvRpAAA1iLuA5B0mor+iP3sxCQFALAWkxSApHZyguJmCjsxSQEArEWTAgBYi7gPAMRzVLZikgIAWIsmBQCwFnEfAJyB6M8eTFIAAGvRpAAA1iLuA4BzIPqLrQuepNauXavhw4fL5/PJ4XBo6dKlEceNMZo+fbqys7OVnp6uwsJC7dq1K+KcgwcPqqioSC6XSxkZGRo3bpwOHz58UQsBACSeC25SR44cUZ8+fTRnzpwmj8+aNUuzZ8/WvHnzVFVVpY4dO2ro0KE6duxY+JyioiJ98sknWrlypZYvX661a9dqwoQJLV8FACAxmYsgySxZsiT8cygUMl6v1zzzzDPhffX19cbpdJo33njDGGPM9u3bjSSzcePG8DnvvfeecTgcZt++fc36dwOBgJFkAoHAxZQPAC32XeD33YYL19zX8ajeOLF79275/X4VFhaG97ndbuXn56uyslKSVFlZqYyMDPXv3z98TmFhoVJSUlRVVdXkdRsaGhQMBiM2AEDii2qT8vv9kiSPxxOx3+PxhI/5/X516dIl4nhqaqoyMzPD55ypvLxcbrc7vOXk5ESzbACApeLiFvTS0lIFAoHwVltbG+uSACS50wM/h+PUhuiKapPyer2SpLq6uoj9dXV14WNer1cHDhyIOP7tt9/q4MGD4XPO5HQ65XK5IjYAQOKLapPKy8uT1+tVRUVFeF8wGFRVVZUKCgokSQUFBaqvr1d1dXX4nFWrVikUCik/Pz+a5QAA4twFP8x7+PBhffbZZ+Gfd+/era1btyozM1O5ubkqKSnRk08+qW7duikvL09lZWXy+XwaMWKEJKlHjx4aNmyYxo8fr3nz5qmxsVETJ07UqFGj5PP5orYwAGgrPPDbii70tsHVq1cbSd/bxowZY4z57jb0srIy4/F4jNPpNEOGDDE1NTUR1/j666/N6NGjTadOnYzL5TJjx441hw4divqtiwDQ1rg1vXma+zruMCb+en0wGJTb7VYgEOD9KQBWYZJqnua+jvPZfQAQRUR/0RUXt6ADAJITTQoAYC3iPgBoJUR/F49JCgBgLZoUAMBaxH0A0Aaaiv6I/c6PSQoAYC0mKQBoYycnKG6mOD8mKQCAtWhSAABrEfcBQIzwHNX5MUkBAKxFkwIAWIu4DwAsQPTXNCYpAIC1aFIAAGsR9wGAZYj+TmGSAgBYiyYFALAWcR8AWCzZoz8mKQCAtWhSAABrEfcBQJxIxuiPSQoAYC2aFADAWsR9ABCHkiX6Y5ICAFiLJgUAsBZxHwDEuUSO/pikAADWokkBAKxF3AcACaSp6C+eYz8mKQCAtZikACBBnZyg4vlmCiYpAIC1aFIAAGsR9wFAgovn56iYpAAA1qJJAQCsRdwHAEkk3qI/JikAgLVoUgAAaxH3AUCSiofoj0kKAGAtmhQAwFoX3KTWrl2r4cOHy+fzyeFwaOnSpeFjjY2Nmjp1qnr37q2OHTvK5/Ppzjvv1P79+yOucfDgQRUVFcnlcikjI0Pjxo3T4cOHL3oxAICWMebU5nCc2mLtgpvUkSNH1KdPH82ZM+d7x44eParNmzerrKxMmzdv1ttvv62amhrdfPPNEecVFRXpk08+0cqVK7V8+XKtXbtWEyZMaPkqAAAJyWFMy98iczgcWrJkiUaMGHHWczZu3KiBAwdqz549ys3N1aeffqqePXtq48aN6t+/vyRpxYoVuvHGG/Xf//5XPp/vvP9uMBiU2+1WIBCQy+VqafkAgCa0xU0UzX0db/X3pAKBgBwOhzIyMiRJlZWVysjICDcoSSosLFRKSoqqqqpauxwAwHnYFP216i3ox44d09SpUzV69Ohwp/T7/erSpUtkEampyszMlN/vb/I6
DQ0NamhoCP8cDAZbr2gAgDVabZJqbGzU7bffLmOM5s6de1HXKi8vl9vtDm85OTlRqhIAYLNWaVInG9SePXu0cuXKiLzR6/XqwIEDEed/++23OnjwoLxeb5PXKy0tVSAQCG+1tbWtUTYA4Ayxjv6iHvedbFC7du3S6tWrlZWVFXG8oKBA9fX1qq6uVr9+/SRJq1atUigUUn5+fpPXdDqdcjqd0S4VAGC5C25Shw8f1meffRb+effu3dq6dasyMzOVnZ2tW2+9VZs3b9by5ct14sSJ8PtMmZmZSktLU48ePTRs2DCNHz9e8+bNU2NjoyZOnKhRo0Y1684+AEDyuOBb0D/44AP93//93/f2jxkzRn/4wx+Ul5fX5O+tXr1aP/vZzyR99zDvxIkTtWzZMqWkpGjkyJGaPXu2OnXq1KwauAUdAGLrYm9Tb+7r+EU9JxUrNCkAiK22alJ8dh8AwFp8VQcA4II19TUfrZHLMUkBAKzFJAUAuCgnJ6jW+Mw/JikAgLVoUgAAaxH3AQCioqmbKc7cf6GYpAAA1qJJAQCsRdwHAIi6aEV/TFIAAGvRpAAA1iLuAwC0qrNFf83BJAUAsFZcTlInv10kGAzGuBIAQMt89/p9vm+LissmdejQIUlSTk5OjCsBAFyMQ4cOye12n/V4XH7pYSgU0v79+2WMUW5urmpraxP2yw+DwaBycnISeo0S60w0ybDOZFij1HrrNMbo0KFD8vl8Skk5+ztPcTlJpaSkqGvXruG4z+VyJfQfiZQca5RYZ6JJhnUmwxql1lnnuSaok7hxAgBgLZoUAMBacd2knE6nHnvsMTmdzliX0mqSYY0S60w0ybDOZFijFPt1xuWNEwCA5BDXkxQAILHRpAAA1qJJAQCsRZMCAFgrbpvUnDlzdMUVV6hDhw7Kz8/Xhg0bYl3SRSkvL9eAAQPUuXNndenSRSNGjFBNTU3EOceOHVNxcbGysrLUqVMnjRw5UnV1dTGq+OLNnDlTDodDJSUl4X2JssZ9+/bpjjvuUFZWltLT09W7d29t2rQpfNwYo+nTpys7O1vp6ekqLCzUrl27YljxhTtx4oTKysqUl5en9PR0XXnllXriiSciPostHte5du1aDR8+XD6fTw6HQ0uXLo043pw1HTx4UEVFRXK5XMrIyNC4ceN0+PDhNlzFuZ1rjY2NjZo6dap69+6tjh07yufz6c4779T+/fsjrtFmazRxaNGiRSYtLc38/e9/N5988okZP368ycjIMHV1dbEurcWGDh1q5s+fb7Zt22a2bt1qbrzxRpObm2sOHz4cPueee+4xOTk5pqKiwmzatMlcd911ZtCgQTGsuuU2bNhgrrjiCnP11VebSZMmhfcnwhoPHjxoLr/8cnPXXXeZqqoq8/nnn5v333/ffPbZZ+FzZs6cadxut1m6dKn56KOPzM0332zy8vLMN998E8PKL8yMGTNMVlaWWb58udm9e7dZvHix6dSpk/nrX/8aPice1/mPf/zDPProo+btt982ksySJUsijjdnTcOGDTN9+vQx69evNx9++KH58Y9/bEaPHt3GKzm7c62xvr7eFBYWmjfffNPs2LHDVFZWmoEDB5p+/fpFXKOt1hiXTWrgwIGmuLg4/POJEyeMz+cz5eXlMawqug4cOGAkmTVr1hhjvvvDad++vVm8eHH4nE8//dRIMpWVlbEqs0UOHTpkunXrZlauXGl++tOfhptUoqxx6tSp5vrrrz/r8VAoZLxer3nmmWfC++rr643T6TRvvPFGW5QYFTfddJO5++67I/bdcsstpqioyBiTGOs88wW8OWvavn27kWQ2btwYPue9994zDofD7Nu3r81qb66mGvGZNmzYYCSZPXv2GGPado1xF/cdP35c1dXVKiwsDO9LSUlRYWGhKisrY1hZdAUCAUlSZmamJKm6ulqNjY0R6+7evbtyc3Pjbt3FxcW66aabItYiJc4a3333XfXv31+33XabunTpor59++rll18O
H9+9e7f8fn/EOt1ut/Lz8+NqnYMGDVJFRYV27twpSfroo4+0bt063XDDDZISZ52na86aKisrlZGRof79+4fPKSwsVEpKiqqqqtq85mgIBAJyOBzKyMiQ1LZrjLsPmP3qq6904sQJeTyeiP0ej0c7duyIUVXRFQqFVFJSosGDB6tXr16SJL/fr7S0tPAfyUkej0d+vz8GVbbMokWLtHnzZm3cuPF7xxJljZ9//rnmzp2rKVOm6Pe//702btyoBx54QGlpaRozZkx4LU39DcfTOqdNm6ZgMKju3burXbt2OnHihGbMmKGioiJJSph1nq45a/L7/erSpUvE8dTUVGVmZsbluo8dO6apU6dq9OjR4Q+Ybcs1xl2TSgbFxcXatm2b1q1bF+tSoqq2tlaTJk3SypUr1aFDh1iX02pCoZD69++vp556SpLUt29fbdu2TfPmzdOYMWNiXF30vPXWW1qwYIEWLlyoq666Slu3blVJSYl8Pl9CrTOZNTY26vbbb5cxRnPnzo1JDXEX91122WVq167d9+74qqurk9frjVFV0TNx4kQtX75cq1evVteuXcP7vV6vjh8/rvr6+ojz42nd1dXVOnDggK699lqlpqYqNTVVa9as0ezZs5WamiqPxxP3a5Sk7Oxs9ezZM2Jfjx49tHfvXkkKryXe/4YffvhhTZs2TaNGjVLv3r3129/+VpMnT1Z5ebmkxFnn6ZqzJq/XqwMHDkQc//bbb3Xw4MG4WvfJBrVnzx6tXLky4ms62nKNcdek0tLS1K9fP1VUVIT3hUIhVVRUqKCgIIaVXRxjjCZOnKglS5Zo1apVysvLizjer18/tW/fPmLdNTU12rt3b9yse8iQIfr444+1devW8Na/f38VFRWF/zve1yhJgwcP/t7jAzt37tTll18uScrLy5PX641YZzAYVFVVVVyt8+jRo9/7srp27dopFApJSpx1nq45ayooKFB9fb2qq6vD56xatUqhUEj5+fltXnNLnGxQu3bt0j//+U9lZWVFHG/TNUb1Now2smjRIuN0Os2rr75qtm/fbiZMmGAyMjKM3++PdWktdu+99xq3220++OAD88UXX4S3o0ePhs+55557TG5urlm1apXZtGmTKSgoMAUFBTGs+uKdfnefMYmxxg0bNpjU1FQzY8YMs2vXLrNgwQJzySWXmNdffz18zsyZM01GRoZ55513zH/+8x/zq1/9yvpbs880ZswY88Mf/jB8C/rbb79tLrvsMvPII4+Ez4nHdR46dMhs2bLFbNmyxUgyzz77rNmyZUv4zrbmrGnYsGGmb9++pqqqyqxbt85069bNqlvQz7XG48ePm5tvvtl07drVbN26NeL1qKGhIXyNtlpjXDYpY4x5/vnnTW5urklLSzMDBw4069evj3VJF0VSk9v8+fPD53zzzTfmvvvuM5deeqm55JJLzK9//WvzxRdfxK7oKDizSSXKGpctW2Z69eplnE6n6d69u3nppZcijodCIVNWVmY8Ho9xOp1myJAhpqamJkbVtkwwGDSTJk0yubm5pkOHDuZHP/qRefTRRyNeyOJxnatXr27yf4tjxowxxjRvTV9//bUZPXq06dSpk3G5XGbs2LHm0KFDMVhN0861xt27d5/19Wj16tXha7TVGvmqDgCAteLuPSkAQPKgSQEArEWTAgBYiyYFALAWTQoAYC2aFADAWjQpAIC1aFIAAGvRpAAA1qJJAQCsRZMCAFiLJgUAsNb/AwpbWaTT9aupAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Heatmap of W^T @ W for the v_up_proj weight of layer 0's self-attention.\n",
    "# NOTE: `model`, `plt` and `cmap` are not defined in this cell; they are expected\n",
    "# to come from earlier cells of the notebook.\n",
    "v_up_weight = model.model.layers[0].self_attn.v_up_proj.weight.data.to(\"cpu\")\n",
    "v_up_gram = v_up_weight.T @ v_up_weight\n",
    "plt.imshow(v_up_gram, cmap=cmap, interpolation='none')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "transmla",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
